import gc
import hashlib
import os
import random
from operator import itemgetter

import streamlit as st
import whisper
from langchain_core.output_parsers import StrOutputParser
from langchain_ollama.chat_models import ChatOllama

from prompts.multi_query import multi_query_prompt, faq_mp_rag_prompt
# from prompts.rag import simple_faq_prompt
from utils.embedding_model import init_embedding
from utils.unique_union import get_unique_union
from utils.vector_store import get_vector_store


# Page-wide stylesheet for the Persian UI: forces right-to-left direction on
# the page, form controls, header, main area and chat input, and centers text.
_RTL_STYLESHEET = """
<style>
    html, body, [class*="css"] {
        direction: rtl;
        text-align: center;
    }
    button, input, optgroup, select, textarea {
        direction: rtl;
    }
    .stButton button {
        direction: rtl;
    }
    header {
        direction: rtl;
    }
    .main {
        direction: rtl;
    }
    .stChatInput {
        direction: rtl;
    }
</style>
"""
# unsafe_allow_html is required so the <style> tag is emitted as HTML.
st.markdown(_RTL_STYLESHEET, unsafe_allow_html=True)

# Brand logo; clicking it navigates to the configured link.
st.logo(
    image="static/images/bimika.jpg",
    size="large",
    link="http://localhost:8501",
)

def _split_queries(text):
    """Split the LLM's multi-query output into one non-empty query per line."""
    # Blank or whitespace-only lines would otherwise be sent to the retriever
    # as queries, costing an embedding call and adding unrelated documents.
    return [line.strip() for line in text.splitlines() if line.strip()]


@st.cache_resource
def init_models():
    """Build the RAG answer chain and load the speech-recognition model.

    Wrapped in ``st.cache_resource`` so the models are constructed once and
    reused instead of being rebuilt on every script rerun.

    Returns:
        tuple: ``(rag_chain, asr_model)`` where ``rag_chain`` is a runnable
        invoked with ``{"question": <str>}`` that yields the answer text, and
        ``asr_model`` is the loaded Whisper model.
    """
    # The second value returned by init_embedding is not needed here.
    embedding_model, _ = init_embedding(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": "cpu", "trust_remote_code": True},
        test_query="Embedding Model Initialized Successfully.",
    )

    # NOTE(review): k_arg=1 presumably means one document per query; confirm
    # against utils.vector_store.
    retriever = get_vector_store(
        persist_directory="db/store_500",
        embedding_model=embedding_model,
        k_arg=1,
    )

    # The same LLM instance is used for query generation and for answering.
    llm = ChatOllama(
        model="7shi/ezo-common-gemma-2:9b-instruct-q4_K_M",
        temperature=0.0,
        verbose=True,
        num_predict=256,
    )

    # Stage 1: turn the user's question into several alternative queries.
    query_prompt = multi_query_prompt()
    generate_queries = (
        query_prompt
        | llm
        | StrOutputParser()
        | _split_queries
    )

    # Stage 2: run the retriever on every query and merge the results.
    # NOTE(review): if the LLM emits no usable line, get_unique_union receives
    # an empty list; confirm it handles that case.
    retrieval_chain = generate_queries | retriever.map() | get_unique_union

    # Stage 3: answer the original question from the retrieved context.
    # answer_prompt = simple_faq_prompt()
    answer_prompt = faq_mp_rag_prompt()
    rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | answer_prompt
        | llm
        | StrOutputParser()
    )

    asr_model = whisper.load_model("large-v3-turbo")

    return rag_chain, asr_model


# Obtain the RAG chain and the ASR model from init_models().
rag_chain, asr_model = init_models()


st.title("چت بات هوشمند پاسخگویی به سوالات مشتریان")

# The conversation transcript is kept in session state; create it on first use.
if "history" not in st.session_state:
    st.session_state["history"] = []

# Replay every stored message so the whole conversation is shown.
for entry in st.session_state["history"]:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Two input channels: a voice recording and a typed question.
audio = st.audio_input("🤩 با صدای خودتان سوال بپرسید")
prompt = st.chat_input("😇 سوال خود را از دستیار هوشمند ما بپرسید")

if prompt:
    # Show the typed question and record it in the transcript.
    with st.chat_message("user"):
        st.markdown(prompt)
    user_turn = {"role": "user", "content": prompt}
    st.session_state.history.append(user_turn)

    # Run the RAG chain on the question.
    response = rag_chain.invoke({"question": prompt})

    # Show the answer and record it in the transcript.
    with st.chat_message("assistant"):
        st.markdown(response)
    assistant_turn = {"role": "assistant", "content": response}
    st.session_state.history.append(assistant_turn)

    # Explicit garbage-collection pass after each answered question.
    gc.collect()


if audio:
    # The audio widget keeps returning the same recording on every rerun
    # (for example when a typed question is submitted afterwards). Fingerprint
    # the recording so each one is transcribed and answered only once.
    audio_buffer = audio.getbuffer()
    audio_digest = hashlib.sha256(audio_buffer).hexdigest()

    if st.session_state.get("last_audio_digest") != audio_digest:
        st.chat_message("assistant").markdown("... پردازش صوت")

        # Save the recording under a content-derived name: unlike a random
        # number, it cannot collide with a different recording.
        os.makedirs("audio", exist_ok=True)
        sample_path = f"audio/{audio_digest}_test.wav"
        with open(sample_path, "wb") as f:
            f.write(audio_buffer)

        result = asr_model.transcribe(sample_path)
        question = result["text"]

        # Mark the recording as handled before touching the history, so a
        # failure further down cannot duplicate entries on the next rerun.
        st.session_state["last_audio_digest"] = audio_digest

        st.chat_message("user").markdown(question)
        st.session_state.history.append({"role": "user", "content": question})

        # ToDo: Error handling if not Persian
        response = rag_chain.invoke({"question": question})

        st.chat_message("assistant").markdown(response)
        st.session_state.history.append({"role": "assistant", "content": response})

        gc.collect()
